import argparse
import math
import os
import pdb
import pickle
import random
import sys
from pprint import pprint
from statistics import mean

import numpy as np
import pandas as pd
import statsmodels.api as sm
import yaml
from joblib import dump, load
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression as reg
from sklearn.model_selection import RandomizedSearchCV
from sklearn.model_selection import train_test_split

# Seed Python's `random` module; kept ahead of the project imports below,
# matching the original ordering.
random.seed(50)

sys.path.insert(1,'/REDACTED/fairness/code/rf/scripts/rf')
import tune_utils as tu
from gpd_test import *



###############################################
#### GET PLOT DICTIONARIES
###############################################


###############################################
#### ORIGINAL REG, CLS, ORACLE MODELS
###############################################


## eitc classifier (model='cls', thresh=100) with dep_database features.
## This run does not bootstrap (bootstrap=False) and writes to the
## final_paper_figs directory rather than modeled_refactor_temp.
## bootstrap_iters is still passed; presumably it is ignored when
## bootstrap=False -- confirm in gpd_test.

get_plot_dict_new(
    acsource='return',
    eitc=True,
    rand_seed=50,
    bootstrap=False,
    bootstrap_iters=100,
    model='cls',
    retrain=False,
    thresh=100,
    dep_database=True,
    write_test=False,
    one_fold_params=False,
    fold_for_params=0,
    datapath='/REDACTED/fairness/code/rf/data/',
    outdir='/REDACTED/audit_descr_stats/data/final_paper_figs',
)

## eitc rf regression (model='reg') with dep_database features,
## bootstrapped over 100 iterations, using the existing fit (retrain=False).
get_plot_dict_new(
    acsource='return',
    eitc=True,
    rand_seed=50,
    bootstrap=True,
    bootstrap_iters=100,
    model='reg',
    retrain=False,
    thresh=100,
    dep_database=True,
    write_test=False,
    one_fold_params=False,
    fold_for_params=0,
    datapath='/REDACTED/fairness/code/rf/data/',
    outdir='/REDACTED/data/modeled_refactor_temp/',
)

## eitc oracle
# Oracle run (model='oracle'), bootstrapped over 100 iterations, no threshold.
# NOTE(review): unlike the cls/reg calls above, dep_database, write_test and
# fold_for_params are not passed here, so get_plot_dict_new's defaults apply
# -- confirm that is intended.
get_plot_dict_new(
    acsource='return',
    eitc=True,
    rand_seed=50,
    bootstrap=True,
    bootstrap_iters=100,
    model='oracle',
    retrain=False,
    thresh=None,
    one_fold_params=False,
    datapath='/REDACTED/fairness/code/rf/data/',
    outdir='/REDACTED/data/modeled_refactor_temp/',
)

###############################################
#### Different Classifier Thresholds
###############################################

# One bootstrapped classifier run (dep_database features) per threshold,
# all writing to the same output directory.
thresh_list = [1, 50, 100, 500, 1000, 2000]
for thresh in thresh_list:
    # thresh_list already holds ints, so each value is passed through as-is.
    get_plot_dict_new(
        acsource='return',
        eitc=True,
        activity_code=None,
        rand_seed=50,
        bootstrap=True,
        bootstrap_iters=100,
        model='cls',
        retrain=False,
        thresh=thresh,
        dep_database=True,
        write_test=False,
        one_fold_params=False,
        fold_for_params=None,
        datapath='/REDACTED/fairness/code/rf/data/',
        outdir='/REDACTED/data/modeled_refactor_temp/',
    )


###############################################
### REFUNDABLE CREDIT OUTCOME (ref_cred_amt_dif_pv)
###############################################

# Re-adding the scripts directory presumably lets this section run on its own;
# it is already on sys.path when the top of the file has run.
sys.path.insert(1, '/REDACTED/fairness/code/rf/scripts/rf')
# This star import rebinds get_plot_dict_new (and any other exported names) to
# the gpd_test_alt_outcome versions for the rest of the script. That version is
# called here with an `outcome` keyword.
from gpd_test_alt_outcome import *

get_plot_dict_new(
    acsource='return',
    eitc=True,
    activity_code=None,
    rand_seed=50,
    bootstrap=True,
    bootstrap_iters=100,
    model='reg',
    retrain=True,
    thresh=None,
    dep_database=True,
    write_test=False,
    one_fold_params=False,
    fold_for_params=None,
    datapath='/REDACTED/fairness/code/rf/data/',
    outdir='/REDACTED/data/modeled_refactor_temp/',
    outcome='ref_cred_amt_dif_pv',
)



###############################################
### SELECTING AUDITS BASED ON REVENUE (INCORPORATING AUDIT COSTS)
###############################################

sys.path.insert(1, '/REDACTED/fairness/code/rf/scripts/rf')
# Star-import gpd_test again so get_plot_dict_new refers to gpd_test's
# version, not the gpd_test_alt_outcome one bound by the refundable-credit
# section. The module is already cached in sys.modules, so this only rebinds
# names; it does not re-execute the module.
from gpd_test import *

# Regression run with net_rev_version=True, written to a separate net_rev/
# directory. Presumably this ranks by revenue net of audit cost, per the
# section header -- confirm in gpd_test.
get_plot_dict_new(
    acsource='return',
    eitc=True,
    activity_code=None,
    rand_seed=50,
    bootstrap=True,
    bootstrap_iters=100,
    model='reg',
    retrain=True,
    thresh=None,
    dep_database=True,
    write_test=False,
    one_fold_params=False,
    fold_for_params=None,
    datapath='/REDACTED/fairness/code/rf/data/',
    outdir='/REDACTED/data/modeled_refactor_temp/net_rev/',
    net_rev_version=True,
)

# Oracle counterpart of the net-revenue regression run above: same options,
# with model='oracle'.
get_plot_dict_new(
    acsource='return',
    eitc=True,
    activity_code=None,
    rand_seed=50,
    bootstrap=True,
    bootstrap_iters=100,
    model='oracle',
    retrain=True,
    thresh=None,
    dep_database=True,
    write_test=False,
    one_fold_params=False,
    fold_for_params=None,
    datapath='/REDACTED/fairness/code/rf/data/',
    outdir='/REDACTED/data/modeled_refactor_temp/net_rev/',
    net_rev_version=True,
)


##############################################################
"""
###############################################
### NOT TRAINING ON NONRESPONSES (FINAL_WGT VARIABLE)
### PLOT NOT USED 
###############################################

sys.path.insert(1,'/REDACTED/fairness/code/rf/scripts/rf')
import tune_utils as tu

### get train and test data for each fold
tu.k_fold_split(random_state=50,
                 n_splits=5,
                 eitc=True,
                 dep_database=True,
                 datapath='/REDACTED/fairness/code/rf/data/final_wgt_data/')

### NON RESPONSES
for fold in range(5):
        tu.tune_model(train_fold=fold, 
                eitc=True,
                dep_database=True,
                datapath='/REDACTED/fairness/code/rf/data/final_wgt_data/',
                model_type= 'reg', ## options are 'reg', 'cls'
                threshold= None, ## options (for now) are None or 100
                ts=0.25, ## test size
                cvs=5,
                njs=10)


research_audits_clean = pd.read_csv('/REDACTED/fairness/code/rf/data/clean_rf_data.csv')
research_audits_clean_dep_database = pd.read_csv('/REDACTED/fairness/code/rf/data/clean_rf_data_plus_dep_database.csv')
research_audits_wgt_var = pd.read_csv('/REDACTED/data/clean/research_audits_wfinest_race.csv')


for df in [research_audits_clean, research_audits_clean_dep_database, research_audits_wgt_var]:
    df.tax_period = df.tax_period.astype(int)
    df['tax_yr'] = [math.floor(x/100) for x in df.tax_period]
    ##df = df.drop(columns=['tax_period'])


research_audits_wgt_var = research_audits_wgt_var[['taxpayer_id', 'tax_yr', 'final_wgt']]
research_audits_wgt_var = research_audits_wgt_var.drop_duplicates(subset = ['taxpayer_id', 'tax_yr'])

overlap = [x for x in research_audits_clean.columns if x in research_audits_wgt_var.columns]

research_audits_merge = research_audits_clean.merge(research_audits_wgt_var, how = 'left', on = overlap)


##research_audits_merge[research_audits_merge['final_wgt']==0].chg_in_tax_owed.value_counts()
##research_audits_merge[research_audits_merge.final_wgt.isna()].chg_in_tax_owed.value_counts()

##overlap_dep_database = [x for x in research_audits_clean_dep_database.columns if x in research_audits_wgt_var.columns]

research_audits_merge_dep_database = research_audits_clean_dep_database.merge(research_audits_wgt_var, how = 'left', on = overlap)

### drop if final_wgt is 0 or null
research_audits_merge_deliver = research_audits_merge[research_audits_merge['final_wgt']!=0]
research_audits_merge_deliver = research_audits_merge_deliver[research_audits_merge_deliver['final_wgt'].notna()]

research_audits_merge_dep_database_deliver = research_audits_merge_dep_database[research_audits_merge_dep_database['final_wgt']!=0]
research_audits_merge_dep_database_deliver = research_audits_merge_dep_database_deliver[research_audits_merge_dep_database_deliver['final_wgt'].notna()]


research_audits_merge_deliver.to_csv('/REDACTED/fairness/code/rf/data/final_wgt_data/clean_rf_data.csv', index=False)
research_audits_merge_dep_database_deliver.to_csv('/REDACTED/fairness/code/rf/data/final_wgt_data/clean_rf_data_plus_dep_database.csv', index=False)



## how many eitc claimants did not respond?
research_audits_eic = research_audits_merge_dep_database[research_audits_merge_dep_database['eitc_amt']>0]

len(research_audits_eic[research_audits_eic['final_wgt']==0])
len(research_audits_eic[research_audits_eic.final_wgt.isna()])

research_audits_eic_no_res = research_audits_eic[research_audits_eic['final_wgt']==0 & research_audits_eic['final_wgt'].isna()]

research_audits_eic_no_res.groupby(['tax_yr'])['base_weight'].sum()
research_audits_eic.groupby(['tax_yr'])['base_weight'].sum()


sys.path.insert(1,'/REDACTED/fairness/code/rf/scripts/rf')
from gpd_test import *
get_plot_dict_new(acsource='return',
                eitc=True,
                activity_code = None,
                rand_seed=50,
                bootstrap=True,
                bootstrap_iters=100,
                model='reg',
                retrain=True,
                thresh=None,
                dep_database=True,
                write_test=False,
                one_fold_params=False,
                fold_for_params=None,
                datapath='/REDACTED/fairness/code/rf/data/final_wgt_data/',
                outdir='/REDACTED/data/modeled_refactor_temp/final_wgt/')


get_plot_dict_new(acsource='return',
                eitc=True,
                activity_code = None,
                rand_seed=50,
                bootstrap=True,
                bootstrap_iters=100,
                model='oracle',
                retrain=True,
                thresh=None,
                dep_database=True,
                write_test=False,
                one_fold_params=False,
                fold_for_params=None,
                datapath='/REDACTED/fairness/code/rf/data/final_wgt_data/',
                outdir='/REDACTED/data/modeled_refactor_temp/final_wgt/')
"""